import torch
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
import math
import numpy as np

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

from cbml_benchmark.modeling import registry


# Public model constructors exported by ``from <module> import *``.
__all__ = [
    'deit_tiny_patch16_224',
    'deit_small_patch16_224',
    'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224',
    'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224',
    'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    """timm ``VisionTransformer`` with a DeiT distillation token and a centroid head.

    Compared with the parent class this adds:

    * ``dist_token``: a second learned token placed right after the class token,
      with its own positional-embedding slot and its own classifier ``head_dist``;
    * ``centroids3``: ``num_clusters[3]`` learnable centroids of size
      ``embed_dim``. :meth:`forward_features` aggregates the class-token
      embedding against them via :meth:`ra3`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # Two extra positions: one for the class token, one for the distillation token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)

        # Side length of the square centroid grid for each aggregation level;
        # level i has num[i] ** 2 centroids. Only level 3 is used by forward.
        self.num = [64, 32, 1, 8]
        self.num_clusters = [side * side for side in self.num]

        # Initial centroid values. Kept as an attribute for code that reads it.
        self.val3 = torch.randn(self.num_clusters[3], self.embed_dim)

        self.head_dist.apply(self._init_weights)
        # Build the centroid parameter exactly once. Passing the initialiser to
        # ``self.apply`` would rebuild it for every submodule.
        self._init_centroids(self)

    def _init_centroids(self, m=None):
        """(Re)create the ``centroids3`` parameter from ``val3``.

        The values are copied onto the device and dtype of the model's existing
        weights instead of a hard-coded ``"cuda"``. This lets the model be built
        on CPU-only machines and moved with ``.to(device)`` as usual.

        Args:
            m: ignored; accepted for backward compatibility with the old
                ``self.apply(self._init_centroids)`` call style.
        """
        self.centroids3 = nn.Parameter(self.val3.clone().to(self.dist_token))

    def ra3(self, x, index):
        """Softly assign the features of ``x`` to the centroids and aggregate residuals.

        For centroid ``k`` (L2-normalised, ``c_k``) and location ``p``, the
        assignment weight is ``softplus(<x_p, c_k> / sqrt(K))``. Here
        ``K = num_clusters[index]``. The output for ``k`` is
        ``sum_p w[k, p] * (x_p - c_k) / C``, L2-normalised over channels.

        Args:
            x: 4-D feature map ``(N, C, H, W)``. ``C`` must equal the centroid
                size (``embed_dim``).
            index: entry of ``self.num`` / ``self.num_clusters`` that describes
                the centroid grid. Only the first ``num_clusters[index]`` rows of
                ``centroids3`` are used.

        Returns:
            Tensor of shape ``(N, C, num[index], num[index])``.
        """
        n, c, _, _ = x.shape
        k = self.num_clusters[index]
        feats = x.reshape(n, c, -1)                                  # (N, C, HW)
        centroids = F.normalize(self.centroids3[:k], p=2, dim=1)     # (K, C)

        # softplus(t) == log(1 + exp(t)), but without overflowing for large t.
        logits = torch.einsum('nch,kc->nkh', feats, centroids) / math.sqrt(k)
        weights = F.softplus(logits)                                 # (N, K, HW)

        # sum_p w[k,p] * (x_p - c_k) == sum_p w[k,p] * x_p - (sum_p w[k,p]) * c_k,
        # computed for all centroids at once instead of looping over them.
        weighted_feats = torch.einsum('nch,nkh->nkc', feats, weights)
        weighted_cents = weights.sum(dim=-1, keepdim=True) * centroids.unsqueeze(0)
        ra = (weighted_feats - weighted_cents) / c                   # (N, K, C)
        ra = F.normalize(ra, p=2, dim=2)
        return ra.permute(0, 2, 1).reshape(n, c, self.num[index], self.num[index])

    def forward_features(self, x):
        """Run the transformer and aggregate the class-token embedding against the centroids.

        Adapted from timm's ``VisionTransformer.forward_features``
        (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py),
        with the distillation token inserted after the class token.

        Returns:
            ``(ra, x_dist, centroids3, centroids3)``.
            ``ra`` is :meth:`ra3` applied to the final class-token embedding,
            viewed as a ``(B, embed_dim, 1, 1)`` map; it has shape
            ``(B, embed_dim, num[3], num[3])``.
            ``x_dist`` is the final distillation-token embedding, ``(B, embed_dim)``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        x_dist = x[:, 1]
        # Class-token embedding as a (B, embed_dim, 1, 1) "feature map" for ra3.
        cls_map = x[:, 0].unsqueeze(-1).unsqueeze(-1)
        ra = self.ra3(cls_map, 3).contiguous()
        return ra, x_dist, self.centroids3, self.centroids3

    def forward(self, x):
        """Return ``(head(ra), head_dist(x_dist), centroids3, centroids3)``.

        The same tuple is returned in training and evaluation mode.
        ``self.head`` receives the 4-D output of :meth:`ra3`. This presumably
        relies on the head being an identity, e.g. after ``rm_head`` or with
        ``num_classes=0`` — TODO confirm before using a linear classifier head.
        """
        x, x_dist, c3, _ = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        return x, x_dist, c3, c3

def rm_head(m):
    """Swap any classifier head children of ``m`` for ``nn.Identity``.

    Only direct children named ``head``, ``fc`` or ``head_dist`` are replaced;
    each keeps its position among ``m``'s children. Other children are untouched.
    """
    present = {child_name for child_name, _ in m.named_children()}
    for head_name in ("head", "fc", "head_dist"):
        if head_name in present:
            m.add_module(head_name, nn.Identity())

def freeze(model, num_block):
    """Stop gradient updates for the stem and the first ``num_block`` blocks.

    Sets ``requires_grad`` to ``False`` on every parameter of
    ``model.patch_embed``, then ``model.pos_drop``, then
    ``model.blocks[0] .. model.blocks[num_block - 1]``, in that order. If
    ``model.blocks`` has fewer than ``num_block`` entries, the out-of-range
    index raises (IndexError for ``nn.Sequential``). Modules visited before
    that point remain frozen.
    """
    def _targets():
        yield model.patch_embed
        yield model.pos_drop
        for block_idx in range(num_block):
            yield model.blocks[block_idx]

    for module in _targets():
        for weight in module.parameters():
            weight.requires_grad_(False)


@registry.BACKBONES.register('deit-t')
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """Build DeiT-Tiny (16x16 patches, embed 192, 12 blocks, 3 heads).

    Args:
        pretrained: when true, load the official DeiT ImageNet weights (strict).
        **kwargs: forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(patch_size=16, embed_dim=192, depth=12, num_heads=3,
                              mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model

@registry.BACKBONES.register('deit-s')
def deit_small_patch16_224(pretrained=False, **kwargs):
    """Build DeiT-Small (16x16 patches, embed 384, 12 blocks, 6 heads).

    Args:
        pretrained: when true, load the official DeiT ImageNet weights (strict).
        **kwargs: forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6,
                              mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model


@registry.BACKBONES.register('deit-b')
def deit_base_patch16_224(pretrained=False, **kwargs):
    """Build DeiT-Base (16x16 patches, embed 768, 12 blocks, 12 heads).

    Args:
        pretrained: when true, load the official DeiT ImageNet weights (strict).
        **kwargs: forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12,
                              mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model


@registry.BACKBONES.register('deit-td')
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """Build distilled DeiT-Tiny (16x16 patches, embed 192, 12 blocks, 3 heads).

    Args:
        pretrained: when true, load the official distilled DeiT ImageNet weights.
        **kwargs: forwarded to ``DistilledVisionTransformer``.

    Raises:
        RuntimeError: if the checkpoint is missing any model key other than
            ``centroids3``, or contains keys the model does not have.
    """
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        # DistilledVisionTransformer registers a project-specific ``centroids3``
        # parameter that the upstream DeiT checkpoint presumably lacks, so a
        # strict load always failed. Load non-strictly, but tolerate only that
        # key; it keeps its random initialisation.
        missing, unexpected = model.load_state_dict(checkpoint["model"], strict=False)
        missing = [key for key in missing if key != "centroids3"]
        if missing or unexpected:
            raise RuntimeError(
                "Error loading pretrained DeiT weights: missing keys {}, unexpected keys {}".format(
                    missing, unexpected))
    return model


@registry.BACKBONES.register('deit-sd')
def deit_small_distilled_patch16_224(pretrained=True, **kwargs):
    """Build distilled DeiT-Small (16x16 patches, embed 384, 12 blocks, 6 heads).

    When ``pretrained`` is true (the default), three steps follow:

    * the official distilled ImageNet weights are loaded with ``strict=False``,
      so parameters absent from the checkpoint keep their initial values;
    * the classifier heads are replaced by identities via ``rm_head``;
    * ``freeze(model, 0)`` is applied, which freezes only the patch embedding
      (and ``pos_drop``), not any transformer block.

    Args:
        pretrained: whether to load weights and apply the head/freeze setup.
        **kwargs: forwarded to ``DistilledVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = DistilledVisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6,
                                       mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"], strict=False)
    rm_head(model)
    freeze(model, 0)
    return model


@registry.BACKBONES.register('deit-bd')
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """Build distilled DeiT-Base (16x16 patches, embed 768, 12 blocks, 12 heads).

    Args:
        pretrained: when true, load the official distilled DeiT ImageNet weights.
        **kwargs: forwarded to ``DistilledVisionTransformer``.

    Raises:
        RuntimeError: if the checkpoint is missing any model key other than
            ``centroids3``, or contains keys the model does not have.
    """
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        # DistilledVisionTransformer registers a project-specific ``centroids3``
        # parameter that the upstream DeiT checkpoint presumably lacks, so a
        # strict load always failed. Load non-strictly, but tolerate only that
        # key; it keeps its random initialisation.
        missing, unexpected = model.load_state_dict(checkpoint["model"], strict=False)
        missing = [key for key in missing if key != "centroids3"]
        if missing or unexpected:
            raise RuntimeError(
                "Error loading pretrained DeiT weights: missing keys {}, unexpected keys {}".format(
                    missing, unexpected))
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """Build DeiT-Base for 384x384 input (16x16 patches, embed 768, 12 blocks, 12 heads).

    Registered with timm's ``register_model`` rather than this project's backbone registry.

    Args:
        pretrained: when true, load the official DeiT ImageNet weights (strict).
        **kwargs: forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
                              mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth"
    state = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model


@registry.BACKBONES.register('deit-bdl')
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """Build distilled DeiT-Base for 384x384 input (16x16 patches, embed 768, 12 blocks, 12 heads).

    Args:
        pretrained: when true, load the official distilled DeiT ImageNet weights.
        **kwargs: forwarded to ``DistilledVisionTransformer``.

    Raises:
        RuntimeError: if the checkpoint is missing any model key other than
            ``centroids3``, or contains keys the model does not have.
    """
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        # DistilledVisionTransformer registers a project-specific ``centroids3``
        # parameter that the upstream DeiT checkpoint presumably lacks, so a
        # strict load always failed. Load non-strictly, but tolerate only that
        # key; it keeps its random initialisation.
        missing, unexpected = model.load_state_dict(checkpoint["model"], strict=False)
        missing = [key for key in missing if key != "centroids3"]
        if missing or unexpected:
            raise RuntimeError(
                "Error loading pretrained DeiT weights: missing keys {}, unexpected keys {}".format(
                    missing, unexpected))
    return model
